import torch
import math


def total_loss(data, label):
    """Total loss = confidence loss + bounding-box regression loss.

    :param data: network output of shape Nx4x5x5 (batch size, height, width,
        prediction); the prediction axis holds x, y, w, h, conf. The x, y, w, h
        channels are multiplied by 128 before the two losses are computed.
    :param label: ground-truth boxes as x, y, w, h (assumed shape (N, 4) and in
        the same units as the scaled predictions -- presumably pixels with
        128-pixel grid cells; confirm against the data pipeline).
    :return: scalar loss tensor.
    """
    # Scale the box channels out-of-place. The previous in-place
    # ``data[..., :4] *= 128`` silently modified the caller's tensor (calling
    # twice scaled twice) and breaks autograd when ``data`` is a leaf that
    # requires grad or was saved for backward by the preceding op.
    data = torch.cat((data[..., :4] * 128, data[..., 4:]), dim=-1)
    conf_loss = confident_loss(data, label)
    loc_loss = bounding_box_loss(data, label)
    return conf_loss + loc_loss


def confident_loss(data, label):
    """Confidence loss.

    Regresses every cell's predicted confidence towards the IoU between that
    cell's predicted box and the ground-truth box of the same sample.

    :param data: predictions laid out as (N, H, W, 5) with the last axis being
        x, y, w, h, conf (box channels already scaled by the caller).
    :param label: ground-truth boxes, assumed shape (N, 4) as x, y, w, h.
    :return: scalar tensor, the summed squared error divided by N.
    """
    batch_size = data.shape[0]
    data_conf = data[..., 4].flatten()
    data_box = data[..., :4].reshape(-1, 4)
    # Number of grid cells per sample, derived from the input instead of the
    # previously hard-coded 20 (a 4x5 grid) so other grid sizes work as well.
    cells_per_sample = data_box.shape[0] // batch_size
    # Pair every cell of sample i with label[i]; rows are ordered
    # (sample, cell), exactly like data_box.
    label_box = label.repeat_interleave(cells_per_sample, dim=0)
    iou = box_iou(data_box, label_box)
    # NOTE(review): gradients flow through the IoU target into the predicted
    # box coordinates as well; many YOLO-style implementations detach this
    # target. Confirm this is intended.
    loss = (iou - data_conf) ** 2
    # Tensor.sum() reduces in one kernel; builtin sum() iterated per element.
    return loss.sum() / batch_size


def bounding_box_loss(data, label):
    """Bounding-box regression loss, computed as 1 - CIoU.

    For each sample only the prediction of the grid cell containing the
    ground-truth box centre (cell size 128) is used.

    :param data: predictions of shape (N, H, W, C) with C >= 4, whose first
        four channels are x, y, w, h in the same units as ``label``.
    :param label: ground-truth boxes, assumed shape (N, 4) as x, y, w, h.
    :return: scalar tensor, the summed (1 - CIoU) divided by N.
    """
    batch_size, grid_h, grid_w = data.shape[:3]
    # Clamp so a centre lying exactly on (or beyond) the far edge maps to the
    # last cell instead of raising IndexError, and a negative centre does not
    # silently wrap around to the opposite side of the grid.
    label_x = (label[:, 0] // 128).long().clamp(0, grid_w - 1).to(data.device)
    label_y = (label[:, 1] // 128).long().clamp(0, grid_h - 1).to(data.device)
    batch_idx = torch.arange(batch_size, device=data.device)
    # Advanced indexing picks one cell per sample -> (N, 4).
    pred = data[batch_idx, label_y, label_x, :4]
    loss = 1 - box_ciou(pred, label)
    return loss.sum() / batch_size


def box_iou(b1, b2, eps=1e-7):
    """Element-wise IoU between paired boxes in x, y, w, h format.

    x, y is the box centre and w, h its size.

    :param b1: tensor of shape (..., 4), e.g. (batch, 4).
    :param b2: tensor of the same (or a broadcastable) shape as ``b1``.
    :param eps: lower bound on the union area so that degenerate (zero-area)
        boxes give an IoU of 0 instead of NaN; normal boxes are unaffected.
    :return iou: tensor of shape (...,), e.g. (batch,) -- one value per pair.
    """
    # Top-left / bottom-right corners of the predicted and true boxes.
    b1_xy, b2_xy = b1[..., :2], b2[..., :2]
    b1_wh, b2_wh = b1[..., 2:4], b2[..., 2:4]
    b1_wh_half, b2_wh_half = b1_wh / 2, b2_wh / 2
    b1_x1y1, b2_x1y1 = b1_xy - b1_wh_half, b2_xy - b2_wh_half
    b1_x2y2, b2_x2y2 = b1_xy + b1_wh_half, b2_xy + b2_wh_half
    # Intersection; negative extents (no overlap) are clipped to zero.
    intersect_min = torch.max(b1_x1y1, b2_x1y1)
    intersect_max = torch.min(b1_x2y2, b2_x2y2)
    intersect_wh = torch.max(intersect_max - intersect_min, torch.zeros_like(intersect_max))
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    # Clamping avoids 0/0 = NaN for zero-area boxes.
    union_area = (b1_area + b2_area - intersect_area).clamp(min=eps)
    iou = intersect_area / union_area
    return iou


def box_ciou(b1, b2, eps=1e-7):
    """Element-wise Complete-IoU (CIoU) between paired boxes in x, y, w, h format.

    CIoU = IoU - d^2 / c^2 - alpha * v, where d is the distance between the box
    centres, c the diagonal of the smallest box enclosing both, v the
    aspect-ratio mismatch term and alpha = v / (1 - IoU + v).

    :param b1: tensor of shape (..., 4), e.g. (batch, 4); x, y is the centre.
    :param b2: tensor of the same (or a broadcastable) shape as ``b1``.
    :param eps: lower bound applied to the union area, the enclosing diagonal
        and ``1 - IoU + v`` so identical or degenerate boxes yield a finite
        value (identical boxes previously produced 0/0 = NaN in ``alpha``).
    :return ciou: tensor of shape (...,), e.g. (batch,).
    """
    # Top-left / bottom-right corners of both boxes.
    b1_xy, b2_xy = b1[..., :2], b2[..., :2]
    b1_wh, b2_wh = b1[..., 2:4], b2[..., 2:4]
    b1_wh_half, b2_wh_half = b1_wh / 2, b2_wh / 2
    b1_x1y1, b2_x1y1 = b1_xy - b1_wh_half, b2_xy - b2_wh_half
    b1_x2y2, b2_x2y2 = b1_xy + b1_wh_half, b2_xy + b2_wh_half
    # Plain IoU; negative overlap extents are clipped to zero.
    intersect_min = torch.max(b1_x1y1, b2_x1y1)
    intersect_max = torch.min(b1_x2y2, b2_x2y2)
    intersect_wh = torch.max(intersect_max - intersect_min, torch.zeros_like(intersect_max))
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = (b1_area + b2_area - intersect_area).clamp(min=eps)
    iou = intersect_area / union_area
    # Squared distance between the centres.
    center_distance = torch.sum((b1_xy - b2_xy) ** 2, dim=-1)
    # Smallest box enclosing both boxes, and its squared diagonal.
    enclose_min = torch.min(b1_x1y1, b2_x1y1)
    enclose_max = torch.max(b1_x2y2, b2_x2y2)
    enclose_wh = torch.max(enclose_max - enclose_min, torch.zeros_like(intersect_max))
    enclose_diagonal = torch.sum(enclose_wh ** 2, dim=-1).clamp(min=eps)
    ciou = iou - center_distance / enclose_diagonal
    # Aspect-ratio consistency term.
    v = (4 / (math.pi ** 2)) * (
        torch.atan(b1_wh[..., 0] / b1_wh[..., 1]) - torch.atan(b2_wh[..., 0] / b2_wh[..., 1])
    ) ** 2
    # For identical boxes iou == 1 and v == 0; without the clamp this is 0/0.
    # NOTE(review): the reference CIoU implementation computes alpha without
    # gradient; here it stays differentiable as in the original code.
    alpha = v / (1 - iou + v).clamp(min=eps)
    return ciou - alpha * v